{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0PD5VPOUr4bs"
      },
      "source": [
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width=\"300\">](https://github.com/jeshraghian/snntorch/)\n",
        "[<img src='https://github.com/neuromorphs/tonic/blob/develop/docs/_static/tonic-logo-white.png?raw=true' width=\"200\">](https://github.com/neuromorphs/tonic/)\n",
        "\n",
        "\n",
        "# Training on ST-MNIST with Tonic + snnTorch Tutorial\n",
        "\n",
        "##### By Dylan Louie (djlouie@ucsc.edu), Hannah Cohen Sandler (hcohensa@ucsc.edu), Shatoparba Banerjee (sbaner12@ucsc.edu)\n",
        "\n",
        "<a href=\"https://colab.research.google.com/drive/1P2yQCDmp7TilNrEqj_cBzS7vscIs0L_o?usp=sharing\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iawcPZ7DtDqK"
      },
      "source": [
        "For a comprehensive overview of how SNNs work and what is going on under the hood, see the [snnTorch tutorial series](https://snntorch.readthedocs.io/en/latest/tutorials/index.html).\n",
        "The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:\n",
        "\n",
        "> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". Proceedings of the IEEE, 111(9) September 2023.](https://ieeexplore.ieee.org/abstract/document/10242251) </cite>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W-v36rDBv41L",
        "outputId": "e548b514-1e5f-4d0d-e8f3-fdd55f228fa8"
      },
      "outputs": [],
      "source": [
        "!pip install tonic --quiet\n",
        "!pip install snntorch --quiet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6WWIF2I1v7sA"
      },
      "outputs": [],
      "source": [
        "# tonic imports\n",
        "import tonic\n",
        "import tonic.transforms as transforms  # Not to be confused with torchvision.transforms\n",
        "from tonic import DiskCachedDataset\n",
        "\n",
        "# torch imports\n",
        "import torch\n",
        "from torch.utils.data import random_split\n",
        "from torch.utils.data import DataLoader\n",
        "import torchvision\n",
        "import torch.nn as nn\n",
        "\n",
        "# snntorch imports\n",
        "import snntorch as snn\n",
        "from snntorch import surrogate\n",
        "import snntorch.spikeplot as splt\n",
        "from snntorch import functional as SF\n",
        "from snntorch import utils\n",
        "\n",
        "# other imports\n",
        "import matplotlib.pyplot as plt\n",
        "from IPython.display import HTML\n",
        "from IPython.display import display\n",
        "import numpy as np\n",
        "import torchdata\n",
        "import os\n",
        "from ipywidgets import IntProgress\n",
        "import time\n",
        "import statistics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "EtBwOxn9Ludx"
      },
      "outputs": [],
      "source": [
        "#@title Plotting Settings\n",
        "def print_frame(copy_events):\n",
        "  print('----------------------------')\n",
        "  print(copy_events[0])\n",
        "  print('----------------------------')\n",
        "  print(copy_events[0][0])\n",
        "  print('----------------------------')\n",
        "  print(copy_events[0][0][0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "McXriEu-tJV6"
      },
      "source": [
        "# 1. The ST-MNIST Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wsV-uUeZ6a2A"
      },
      "source": [
        "## 1.1 Introduction\n",
        "\n",
        "The Spiking Tactile-MNIST (ST-MNIST) dataset features handwritten digits (0-9) inscribed by 23 individuals on a 100-taxel biomimetic event-based tactile sensor array. This dataset captures the dynamic pressure changes associated with natural writing. The tactile sensing system, Asynchronously Coded Electronic Skin (ACES), emulates the human peripheral nervous system, transmitting fast-adapting (FA) responses as asynchronous electrical events.\n",
        "\n",
        "More information about the ST-MNIST dataset can be found in the following paper:\n",
        "\n",
        "> <cite> H. H. See, B. Lim, S. Li, H. Yao, W. Cheng, H. Soh, and B. C. K. Tee, \"ST-MNIST - The Spiking Tactile-MNIST Neuromorphic Dataset,\" A PREPRINT, May 2020. [Online]. Available: https://arxiv.org/abs/2005.04319 </cite>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ickp0FA4_nBR"
      },
      "source": [
        "## 1.2 Downloading the ST-MNIST dataset\n",
        "\n",
        "ST-MNIST is in the `MAT` format. Tonic can be used to transform this into an event-based format (x, y, t, p).\n",
        "\n",
        "1. Download the compressed dataset by accessing: https://scholarbank.nus.edu.sg/bitstream/10635/168106/2/STMNIST%20dataset%20NUS%20Tee%20Research%20Group.zip\n",
        "\n",
        "2. The zip file is `STMNIST dataset NUS Tee Research Group`. Create a parent folder titled `STMNIST` and place the zip file inside.\n",
        "\n",
        "3. If running in a cloud-based environment, e.g., on Colab, you will need to do this in Google Drive."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DnDx0axoCphC"
      },
      "source": [
        "## 1.3 Mount to Drive\n",
        "You may need to authorize the following to access Google Drive:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "c2PcG-B3v9K8",
        "outputId": "7d67801b-2855-40ef-b338-789419af0b50"
      },
      "outputs": [],
      "source": [
        "# Load the Drive helper and mount\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yer7x6Mpwf9q"
      },
      "source": [
        "After executing the cell above, Drive files will be present in \"/content/drive/MyDrive\". You may need to change the `root` variable to your own path."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "00hHZOeuv-8k",
        "outputId": "e60a16e7-3baf-4a20-d1cf-bc99dfb18022"
      },
      "outputs": [],
      "source": [
        "root = \"/content/drive/My Drive/\"  # similar to os.path.join('content', 'drive', 'My Drive')\n",
        "os.listdir(os.path.join(root, 'STMNIST')) # confirm the file exists"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OOfqmhKcIOR0"
      },
      "source": [
        "## 1.4 ST-MNIST Using Tonic\n",
        "\n",
        "`Tonic` will be used to convert the dataset into a format compatible with PyTorch/snnTorch. The documentation can be found [here](https://tonic.readthedocs.io/en/latest/generated/tonic.prototype.datasets.STMNIST.html#tonic.prototype.datasets.STMNIST)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4r9zaUjHwAcM"
      },
      "outputs": [],
      "source": [
        "dataset = tonic.prototype.datasets.STMNIST(root=root, keep_compressed = False, shuffle = False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AgwoqxsAMdqP"
      },
      "source": [
        "Tonic formats the STMNIST dataset into `(x, y, t, p)` tuples.\n",
        "* `x` is the position on the x-axis\n",
        "* `y` is the position on the y-axis\n",
        "* `t` is a timestamp\n",
        "* `p` is polarity; +1 if taxel pressed down, 0 if taxel released\n",
        "\n",
        "Each sample also contains the label, which is an integer 0-9 that corresponds to what digit is being drawn.\n",
        "\n",
        "An example of one of the events is shown below:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2nRzg5A0RegL",
        "outputId": "1ffb205a-ae3d-4ac5-e762-8124ab1938de"
      },
      "outputs": [],
      "source": [
        "events, target = next(iter(dataset))\n",
        "print(events[0])\n",
        "print(target)"
      ]
    },
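    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the `(x, y, t, p)` layout concrete, here is a minimal sketch that builds a few synthetic events as a NumPy structured array and reads individual fields. The event values, the `event_dtype` name, and the integer field types are assumptions made for illustration; the real ST-MNIST events come from Tonic and may use a different dtype.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical dtype mirroring the (x, y, t, p) fields described above\n",
        "event_dtype = np.dtype([('x', np.int64), ('y', np.int64), ('t', np.int64), ('p', np.int64)])\n",
        "\n",
        "# Three made-up events: (x, y, timestamp, polarity)\n",
        "toy_events = np.array([(3, 4, 100, 1), (3, 5, 250, 1), (3, 4, 900, 0)], dtype=event_dtype)\n",
        "\n",
        "print(toy_events[0])    # first event\n",
        "print(toy_events['t'])  # all timestamps\n",
        "\n",
        "# Keep only the events where a taxel was pressed (p == 1)\n",
        "press_events = toy_events[toy_events['p'] == 1]\n",
        "print(len(press_events))\n",
        "```\n",
        "\n",
        "Accessing a field by name returns one value per event, which is convenient for filtering by polarity or time."
      ]
    },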
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7mWT1BXPdeuM"
      },
      "source": [
        "The `ToFrame()` transform from `tonic.transforms` converts events from (x, y, t, p) tuples into a NumPy array of frames."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Alt1gJkWSqjy",
        "outputId": "78706825-055c-4f0d-d11d-1774b8529d50"
      },
      "outputs": [],
      "source": [
        "sensor_size = tuple(tonic.prototype.datasets.STMNIST.sensor_size.values())  # The sensor size for STMNIST is (10, 10, 2)\n",
        "\n",
        "# filter isolated noisy events, then bin events into frames of 20000 time units (20 ms, assuming microsecond timestamps)\n",
        "frame_transform = transforms.Compose([transforms.Denoise(filter_time=10000),\n",
        "                                      transforms.ToFrame(sensor_size=sensor_size,\n",
        "                                                         time_window=20000)\n",
        "                                     ])\n",
        "\n",
        "transformed_events = frame_transform(events)\n",
        "\n",
        "print_frame(transformed_events)"
      ]
    },
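    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, framing counts how many events fall into each time window at each sensor location and polarity. The sketch below illustrates that idea with plain NumPy. It is not Tonic's implementation: the helper name `events_to_frames`, the `(time, polarity, height, width)` output layout, and the toy events are assumptions chosen for illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def events_to_frames(x, y, t, p, sensor_size=(10, 10, 2), time_window=20000):\n",
        "    # Count events per (time bin, polarity, row, column)\n",
        "    width, height, n_polarities = sensor_size\n",
        "    bins = (t - t.min()) // time_window  # time bin index of each event\n",
        "    frames = np.zeros((bins.max() + 1, n_polarities, height, width), dtype=np.int64)\n",
        "    np.add.at(frames, (bins, p, y, x), 1)  # accumulates correctly for repeated indices\n",
        "    return frames\n",
        "\n",
        "# Made-up events: two in the first window, one in the third\n",
        "x = np.array([1, 1, 7])\n",
        "y = np.array([2, 2, 5])\n",
        "t = np.array([0, 15000, 45000])\n",
        "p = np.array([1, 1, 0])\n",
        "\n",
        "toy_frames = events_to_frames(x, y, t, p)\n",
        "print(toy_frames.shape)        # (3, 2, 10, 10)\n",
        "print(toy_frames[0, 1, 2, 1])  # 2\n",
        "```\n",
        "\n",
        "A smaller `time_window` gives more frames with fewer events in each, which is why the number of time steps varies with the window size and the recording length."
      ]
    },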
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l3CJDa1LnUzd"
      },
      "source": [
        "## 1.5 Visualizations\n",
        "\n",
        "\n",
        "Using `tonic.utils.plot_animation`, the frame transform, and some rotation, we can create an animation of the data and visualize it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kOFkuUfrplsg"
      },
      "outputs": [],
      "source": [
        "# Fetch a sample from the dataset to visualize\n",
        "events, target = next(iter(dataset))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "maDf7TLHmUiw",
        "outputId": "06bbdc7d-68bf-4d93-fe90-80db5631cb4c"
      },
      "outputs": [],
      "source": [
        "frame_transform_tonic_visual = tonic.transforms.ToFrame(\n",
        "    sensor_size=(10, 10, 2),\n",
        "    time_window=10000,\n",
        ")\n",
        "\n",
        "frames = frame_transform_tonic_visual(events)\n",
        "frames = frames / np.max(frames)\n",
        "frames = np.rot90(frames, k=-1, axes=(2, 3))\n",
        "frames = np.flip(frames, axis=3)\n",
        "\n",
        "# Print out the Target\n",
        "print('Animation of ST-MNIST')\n",
        "print('The target label is:',target)\n",
        "animation = tonic.utils.plot_animation(frames)\n",
        "\n",
        "# Display the animation inline in a Jupyter notebook\n",
        "HTML(animation.to_jshtml())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w52aUd2qoyXV"
      },
      "source": [
        "We can also use `snntorch.spikeplot`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 926
        },
        "id": "bPwRVZgqo8EH",
        "outputId": "a6eb9949-9b37-4fbc-a9e4-c06a1ba5183a"
      },
      "outputs": [],
      "source": [
        "frame_transform_snntorch_visual = tonic.transforms.ToFrame(\n",
        "    sensor_size=(10, 10, 2),\n",
        "    time_window=8000,\n",
        ")\n",
        "\n",
        "tran = frame_transform_snntorch_visual(events)\n",
        "tran = np.rot90(tran, k=-1, axes=(2, 3))\n",
        "tran = np.flip(tran, axis=3)\n",
        "tran = torch.from_numpy(tran)\n",
        "\n",
        "tensor1 = tran[:, 0:1, :, :]\n",
        "tensor2 = tran[:, 1:2, :, :]\n",
        "\n",
        "print('Animation of ST-MNIST')\n",
        "print('The target label is:',target)\n",
        "\n",
        "fig, ax = plt.subplots()\n",
        "time_steps = tensor1.size(0)\n",
        "tensor1_plot = tensor1.reshape(time_steps, 10, 10)\n",
        "anim = splt.animator(tensor1_plot, fig, ax, interval=10)\n",
        "\n",
        "display(HTML(anim.to_html5_video()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CzYgPlxWfdm_"
      },
      "source": [
        "There are 6,953 recordings in this dataset. The developers of ST-MNIST invited 23 participants to write each of the 10 digits approximately 30 times: 23\\*30\\*10 = 6,900."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_v1auSbyepQr",
        "outputId": "aa10b1b1-52f9-43fb-a183-27e5724fc1b0"
      },
      "outputs": [],
      "source": [
        "print(len(dataset))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tlX9jWV0f_az"
      },
      "source": [
        "## 1.6 Let's create a trainset and testset!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hqQzVEHEgSFp"
      },
      "source": [
        "ST-MNIST isn't already separated into a trainset and testset in Tonic, so we will have to separate it manually. In the process of separating the data, we will also transform it using `.ToFrame()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d_6BFKiXJdWU"
      },
      "outputs": [],
      "source": [
        "sensor_size = tonic.prototype.datasets.STMNIST.sensor_size\n",
        "sensor_size = tuple(sensor_size.values())\n",
        "\n",
        "# Define a transform\n",
        "frame_transform = transforms.Compose([transforms.ToFrame(sensor_size=sensor_size, time_window=20000)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iSMhDsHliQk5"
      },
      "source": [
        "The following code reads out a portion of the dataset, transforms the events using the `frame_transform` defined above, and then separates the data into a trainset and a testset. Because `.ToFrame()` is applied to every sample, this code snippet might take a few minutes.\n",
        "\n",
        "For speed, we will just use a subset of the dataset. By default, this is 640 training samples and 320 testing samples. Feel free to change this if you have more patience than us."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 136,
          "referenced_widgets": [
            "8a99b133e85f4bb695664801a50839c8",
            "fdbee8d3a50947f4ae15bc96922fa93e",
            "65b936fc7fd940db8eab6847ea05d7ca",
            "755b25b8bc1943389b5f9cd53f3c7862",
            "54aed6bf532945ee9c8f756d69714673",
            "156f838174924805b49be2ad66e793b5"
          ]
        },
        "id": "c0qw8uduLpZv",
        "outputId": "415f5422-77bf-459d-9f57-73df1ba4703a"
      },
      "outputs": [],
      "source": [
        "def shorter_transform_STMNIST(data, transform):\n",
        "    short_train_size = 640\n",
        "    short_test_size = 320\n",
        "\n",
        "    train_bar = IntProgress(min=0, max=short_train_size)\n",
        "    test_bar = IntProgress(min=0, max=short_test_size)\n",
        "\n",
        "    testset = []\n",
        "    trainset = []\n",
        "\n",
        "    # Use one iterator over `data` so that every call to next() yields a new sample\n",
        "    data_iter = iter(data)\n",
        "\n",
        "    print('Porting over and transforming the trainset.')\n",
        "    display(train_bar)\n",
        "    for _ in range(short_train_size):\n",
        "        events, target = next(data_iter)\n",
        "        events = transform(events)\n",
        "        trainset.append((events, target))\n",
        "        train_bar.value += 1\n",
        "    print('Porting over and transforming the testset.')\n",
        "    display(test_bar)\n",
        "    for _ in range(short_test_size):\n",
        "        events, target = next(data_iter)\n",
        "        events = transform(events)\n",
        "        testset.append((events, target))\n",
        "        test_bar.value += 1\n",
        "\n",
        "    return (trainset, testset)\n",
        "\n",
        "start_time = time.time()\n",
        "trainset, testset = shorter_transform_STMNIST(dataset, frame_transform)\n",
        "elapsed_time = time.time() - start_time\n",
        "\n",
        "# Convert elapsed time to minutes, seconds, and milliseconds\n",
        "minutes, seconds = divmod(elapsed_time, 60)\n",
        "seconds, milliseconds = divmod(seconds, 1)\n",
        "milliseconds = round(milliseconds * 1000)\n",
        "\n",
        "# Print the elapsed time\n",
        "print(f\"Elapsed time: {int(minutes)} minutes, {int(seconds)} seconds, {milliseconds} milliseconds\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "98VVH_HSs-Gh"
      },
      "source": [
        "## 1.6 Dataloading and Batching\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DPxzp1fdFe_X"
      },
      "outputs": [],
      "source": [
        "# Create a DataLoader\n",
        "dataloader = DataLoader(trainset, batch_size=32, shuffle=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yORjaoQAuuY1"
      },
      "source": [
        "For faster dataloading, we can use `DiskCachedDataset(...)` from Tonic.\n",
        "\n",
        "Due to variations in the lengths of event recordings, `tonic.collation.PadTensors()` will be used to prevent irregular tensor shapes. Shorter recordings are padded, ensuring uniform dimensions across all samples in a batch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YaPsfB0ArUgQ"
      },
      "outputs": [],
      "source": [
        "transform = tonic.transforms.Compose([torch.from_numpy])\n",
        "\n",
        "cached_trainset = DiskCachedDataset(trainset, transform=transform, cache_path='./cache/stmnist/train')\n",
        "\n",
        "# no augmentations for the testset\n",
        "cached_testset = DiskCachedDataset(testset, cache_path='./cache/stmnist/test')\n",
        "\n",
        "batch_size = 32\n",
        "trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)\n",
        "testloader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False))"
      ]
    },
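    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The padding idea can be illustrated without Tonic. The sketch below zero-pads recordings of different lengths to the longest one and stacks them with time as the first axis. It is only an illustration under assumptions: the helper name `pad_and_stack` and the toy shapes are hypothetical, and `tonic.collation.PadTensors` may differ in its details.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def pad_and_stack(samples):\n",
        "    # Zero-pad each (time, ...) array to the longest length, then stack along a new batch axis 1\n",
        "    max_len = max(s.shape[0] for s in samples)\n",
        "    padded = []\n",
        "    for s in samples:\n",
        "        pad_width = [(0, max_len - s.shape[0])] + [(0, 0)] * (s.ndim - 1)\n",
        "        padded.append(np.pad(s, pad_width))\n",
        "    return np.stack(padded, axis=1)\n",
        "\n",
        "# Two made-up recordings with 5 and 8 frames, each frame of shape (2, 10, 10)\n",
        "batch = pad_and_stack([np.ones((5, 2, 10, 10)), np.ones((8, 2, 10, 10))])\n",
        "print(batch.shape)  # (8, 2, 2, 10, 10): time x batch x channels x height x width\n",
        "```\n",
        "\n",
        "Placing time on the first axis matches the `batch_first=False` setting used above, so a loop over the first dimension steps through time."
      ]
    },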
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0so65S95BDbf",
        "outputId": "4a670cd0-1f61-4e89-b124-e357f526ec86"
      },
      "outputs": [],
      "source": [
        "# Query the shape of a sample: time x batch x dimensions\n",
        "data_tensor, targets = next(iter(trainloader))\n",
        "print(data_tensor.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QDGPdoBUw-ME"
      },
      "source": [
        "## 1.7 Create the Spiking Convolutional Neural Network"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PRdPJemVH8uR"
      },
      "source": [
        "By default, we use a spiking convolutional neural network with the architecture: `10×10-32c4-64c3-MaxPool2d(2)-10o`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W2ewqKLx8mMJ"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
        "\n",
        "# neuron and simulation parameters\n",
        "beta = 0.95\n",
        "\n",
        "# This is the same architecture that was used in the STMNIST Paper\n",
        "scnn_net = nn.Sequential(\n",
        "    nn.Conv2d(2, 32, kernel_size=4),\n",
        "    snn.Leaky(beta=beta, init_hidden=True),\n",
        "    nn.Conv2d(32, 64, kernel_size=3),\n",
        "    snn.Leaky(beta=beta, init_hidden=True),\n",
        "    nn.MaxPool2d(2),\n",
        "    nn.Flatten(),\n",
        "    nn.Linear(64 * 2 * 2, 10),  # 64 channels x 2 x 2 spatial positions after pooling\n",
        "    snn.Leaky(beta=beta, init_hidden=True, output=True)\n",
        ").to(device)\n",
        "\n",
        "optimizer = torch.optim.Adam(scnn_net.parameters(), lr=2e-2, betas=(0.9, 0.999))\n",
        "loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)"
      ]
    },
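    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The input size of the final linear layer, `64 * 2 * 2`, follows from the spatial size after each layer. The sketch below works through that arithmetic, assuming the PyTorch defaults used above (stride 1 and no padding for the convolutions, stride equal to the kernel size for max pooling). The helper `conv_out` is a hypothetical name introduced only for this illustration.\n",
        "\n",
        "```python\n",
        "def conv_out(size, kernel, stride=1, padding=0):\n",
        "    # Output length of a convolution or pooling window along one spatial dimension\n",
        "    return (size + 2 * padding - kernel) // stride + 1\n",
        "\n",
        "size = 10                                  # 10x10 taxel grid\n",
        "size = conv_out(size, kernel=4)            # Conv2d(2, 32, kernel_size=4) -> 7\n",
        "size = conv_out(size, kernel=3)            # Conv2d(32, 64, kernel_size=3) -> 5\n",
        "size = conv_out(size, kernel=2, stride=2)  # MaxPool2d(2) -> 2\n",
        "\n",
        "flat_features = 64 * size * size\n",
        "print(size, flat_features)  # 2 256\n",
        "```\n",
        "\n",
        "If you change a kernel size or add a layer, rerun this arithmetic and update the `nn.Linear` input size to match."
      ]
    },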
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sq_jz3xYxMxO"
      },
      "source": [
        "## 1.8 Define the Forward Pass"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ydcyDZDt_qH_"
      },
      "outputs": [],
      "source": [
        "def forward_pass(net, data):\n",
        "    spk_rec = []\n",
        "    utils.reset(net)  # resets hidden states for all LIF neurons in net\n",
        "\n",
        "    for step in range(data.size(0)):  # data.size(0) = number of time steps\n",
        "\n",
        "        spk_out, mem_out = net(data[step])\n",
        "        spk_rec.append(spk_out)\n",
        "\n",
        "    return torch.stack(spk_rec)"
      ]
    },
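    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The forward pass resets the neuron states and then steps through time, because each leaky integrate-and-fire (LIF) neuron carries a membrane potential from one step to the next. The sketch below shows a simplified single-neuron update in plain Python, assuming a threshold of 1.0 and reset by subtraction. The function `leaky_step` is a hypothetical helper, and the exact ordering of operations inside `snn.Leaky` may differ.\n",
        "\n",
        "```python\n",
        "def leaky_step(mem, x, beta=0.95, threshold=1.0):\n",
        "    # One simplified LIF update: decay, integrate the input, spike, then subtract the threshold\n",
        "    mem = beta * mem + x\n",
        "    spk = 1 if mem > threshold else 0\n",
        "    mem = mem - spk * threshold\n",
        "    return spk, mem\n",
        "\n",
        "mem = 0.0  # corresponds to a freshly reset hidden state\n",
        "spikes = []\n",
        "for step in range(4):\n",
        "    spk, mem = leaky_step(mem, x=0.3)  # constant made-up input current\n",
        "    spikes.append(spk)\n",
        "\n",
        "print(spikes)  # [0, 0, 0, 1]\n",
        "```\n",
        "\n",
        "Without a reset between samples, leftover membrane potential from one recording would leak into the next, which is why `utils.reset(net)` is called at the start of every forward pass."
      ]
    },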
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9tPywf6CxWcq"
      },
      "source": [
        "## 1.9 Create and Run the Training Loop\n",
        "\n",
        "This might take a while, so kick back, take a break and eat a snack while this happens; perhaps even count kangaroos to take a nap or do a shoey and get schwasted instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lB9lYUP0AUBL",
        "outputId": "dff0092d-4425-4c07-919e-7eb1bcb92700"
      },
      "outputs": [],
      "source": [
        "start_time = time.time()\n",
        "\n",
        "num_epochs = 30\n",
        "\n",
        "loss_hist = []\n",
        "acc_hist = []\n",
        "\n",
        "# training loop\n",
        "for epoch in range(num_epochs):\n",
        "    for i, (data, targets) in enumerate(iter(trainloader)):\n",
        "        data = data.to(device)\n",
        "        targets = targets.to(device)\n",
        "\n",
        "        scnn_net.train()\n",
        "        spk_rec = forward_pass(scnn_net, data)\n",
        "        loss_val = loss_fn(spk_rec, targets)\n",
        "\n",
        "        # Gradient calculation + weight update\n",
        "        optimizer.zero_grad()\n",
        "        loss_val.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Store loss history for future plotting\n",
        "        loss_hist.append(loss_val.item())\n",
        "\n",
        "        # Print loss every 4 iterations\n",
        "        if i%4 == 0:\n",
        "            print(f\"Epoch {epoch}, Iteration {i} \\nTrain Loss: {loss_val.item():.2f}\")\n",
        "\n",
        "        # Calculate accuracy rate and then append it to accuracy history\n",
        "        acc = SF.accuracy_rate(spk_rec, targets)\n",
        "        acc_hist.append(acc)\n",
        "\n",
        "        # Print accuracy every 4 iterations\n",
        "        if i%4 == 0:\n",
        "            print(f\"Accuracy: {acc * 100:.2f}%\\n\")\n",
        "\n",
        "end_time = time.time()\n",
        "\n",
        "# Calculate elapsed time\n",
        "elapsed_time = end_time - start_time\n",
        "\n",
        "# Convert elapsed time to minutes, seconds, and milliseconds\n",
        "minutes, seconds = divmod(elapsed_time, 60)\n",
        "seconds, milliseconds = divmod(seconds, 1)\n",
        "milliseconds = round(milliseconds * 1000)\n",
        "\n",
        "# Print the elapsed time\n",
        "print(f\"Elapsed time: {int(minutes)} minutes, {int(seconds)} seconds, {milliseconds} milliseconds\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h6wfP5Jjbf2V"
      },
      "source": [
        "Uncomment the code below if you want to save the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g9Fq4CYvbinS"
      },
      "outputs": [],
      "source": [
        "# torch.save(scnn_net.state_dict(), 'scnn_net.pth')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cEY6Ynbq0JmX"
      },
      "source": [
        "# 2. Results"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yYSkN_kp0Lm0"
      },
      "source": [
        "## 2.1 Plot accuracy history"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 472
        },
        "id": "X0SYWQDJ6qhx",
        "outputId": "a16dd2e9-71ae-4552-ed57-b94550531454"
      },
      "outputs": [],
      "source": [
        "# Plot accuracy\n",
        "fig = plt.figure(facecolor=\"w\")\n",
        "plt.plot(acc_hist)\n",
        "plt.title(\"Train Set Accuracy\")\n",
        "plt.xlabel(\"Iteration\")\n",
        "plt.ylabel(\"Accuracy\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hhx12ZJP0eF7"
      },
      "source": [
        "## 2.2 Evaluate the Network on the Test Set"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BTjjGQHuhQ7i",
        "outputId": "4ce8c530-c92a-4759-926d-bf7fc5a48849"
      },
      "outputs": [],
      "source": [
        "# Make sure your model is in evaluation mode\n",
        "scnn_net.eval()\n",
        "\n",
        "# Initialize a list to store the accuracy of each test batch\n",
        "acc_hist = []\n",
        "\n",
        "# Iterate over batches in the testloader\n",
        "with torch.no_grad():\n",
        "    for data, targets in testloader:\n",
        "        # Move data and targets to the device (GPU or CPU)\n",
        "        data = data.to(device)\n",
        "        targets = targets.to(device)\n",
        "\n",
        "        # Forward pass\n",
        "        spk_rec = forward_pass(scnn_net, data)\n",
        "\n",
        "        acc = SF.accuracy_rate(spk_rec, targets)\n",
        "        acc_hist.append(acc)\n",
        "\n",
        "        # if i%10 == 0:\n",
        "        # print(f\"Accuracy: {acc * 100:.2f}%\\n\")\n",
        "\n",
        "print(\"The average accuracy across the testloader is:\", statistics.mean(acc_hist))"
      ]
    },
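    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The accuracy above is based on rate coding: the predicted class is taken to be the output neuron that spikes most often over the recording. The sketch below illustrates this idea with NumPy on made-up spikes. It is an illustration in the spirit of `SF.accuracy_rate`, not the library's implementation, and the array names are hypothetical.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Made-up output spikes with shape (time steps, batch, classes)\n",
        "toy_spk = np.zeros((4, 2, 3))\n",
        "toy_spk[:, 0, 2] = 1   # sample 0: class 2 spikes at every step\n",
        "toy_spk[:2, 1, 0] = 1  # sample 1: class 0 spikes twice\n",
        "toy_spk[0, 1, 1] = 1   # sample 1: class 1 spikes once\n",
        "toy_targets = np.array([2, 1])\n",
        "\n",
        "spike_counts = toy_spk.sum(axis=0)          # (batch, classes)\n",
        "predictions = spike_counts.argmax(axis=1)   # class with the most spikes\n",
        "accuracy = (predictions == toy_targets).mean()\n",
        "print(predictions, accuracy)  # [2 0] 0.5\n",
        "```\n",
        "\n",
        "Here the first sample is classified correctly and the second is not, giving an accuracy of 0.5 for this toy batch."
      ]
    },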
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1a-oeEzK0Xvt"
      },
      "source": [
        "## 2.3 Visualize Spike Recordings\n",
        "\n",
        "The following visual is a spike count histogram for a single sample and its target, built from the recorded output spikes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZJArj6jlXBEs"
      },
      "outputs": [],
      "source": [
        "spk_rec = forward_pass(scnn_net, data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "tJSBM-ctXETI",
        "outputId": "e788d7f4-db1e-4438-fd70-3a9e03884ebc"
      },
      "outputs": [],
      "source": [
        "# Change index to visualize a different sample\n",
        "idx = 0\n",
        "fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))\n",
        "labels=['0', '1', '2', '3', '4', '5', '6', '7', '8','9']\n",
        "print(f\"The target label is: {targets[idx]}\")\n",
        "\n",
        "#  Plot spike count histogram\n",
        "anim = splt.spike_count(spk_rec[:, idx].detach().cpu(), fig, ax, labels=labels,\n",
        "                        animate=True, interpolate=1)\n",
        "\n",
        "display(HTML(anim.to_html5_video()))\n",
        "# anim.save(\"spike_bar.mp4\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dhkWLN0dYr16"
      },
      "source": [
        "# Congratulations!\n",
        "You trained a Spiking CNN using `snnTorch` and `Tonic` on ST-MNIST!"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "156f838174924805b49be2ad66e793b5": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "54aed6bf532945ee9c8f756d69714673": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "65b936fc7fd940db8eab6847ea05d7ca": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "755b25b8bc1943389b5f9cd53f3c7862": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "IntProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "IntProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_54aed6bf532945ee9c8f756d69714673",
            "max": 320,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_156f838174924805b49be2ad66e793b5",
            "value": 320
          }
        },
        "8a99b133e85f4bb695664801a50839c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "IntProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "IntProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fdbee8d3a50947f4ae15bc96922fa93e",
            "max": 640,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_65b936fc7fd940db8eab6847ea05d7ca",
            "value": 640
          }
        },
        "fdbee8d3a50947f4ae15bc96922fa93e": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
